{
  "cells": [
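    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "(tir_blitz)=\n",
        "# Blitz Course to TensorIR\n",
        "\n",
        "**Author**: [Siyuan Feng](https://github.com/Hzfengsy)\n",
        "\n",
        "TensorIR is a domain-specific language for deep learning programs that serves two broad purposes:\n",
        "\n",
        "- An implementation for transforming and optimizing programs on various hardware backends.\n",
        "- An abstraction for automatic tensorized program optimization."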
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import tvm\n",
        "from tvm.script import ir as I\n",
        "from tvm.ir.module import IRModule\n",
        "from tvm.script import tir as T\n",
        "import numpy as np"
      ]
    },
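    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## IRModule\n",
        "\n",
        "An IRModule is the central data structure in TVM; it contains deep learning programs. It is the basic object of interest for IR transformation and model building.\n",
        "\n",
        "```{image} https://daobook.github.io/tvm-web-data/images/design/tvm_life_of_irmodule.png\n",
        ":width: 85%\n",
        "```\n",
        "\n",
        "This is the life cycle of an IRModule, which can be created from TVMScript. TensorIR schedule primitives and passes are the two main ways to transform an IRModule, and a sequence of transformations can be applied to the same IRModule. Note that an IRModule can be printed as TVMScript at **any** stage. After all transformations and optimizations are complete, the IRModule can be built into a runnable module for deployment on target devices.\n",
        "\n",
        "The design of TensorIR and IRModule enables a new way of programming:\n",
        "\n",
        "1. Write a program in TVMScript, a syntax based on the Python AST.\n",
        "2. Transform and optimize the program with Python APIs.\n",
        "3. Interactively inspect the program and try out performance changes through an imperative-style transformation API.\n",
        "\n",
        "## Create an IRModule\n",
        "\n",
        "An IRModule can be created by writing TVMScript, which is a round-trippable syntax for TVM IR.\n",
        "\n",
        "Unlike creating a computational expression with [Tensor Expression](tutorial-tensor-expr-get-started), TensorIR lets users write programs in TVMScript, a language embedded in the Python AST. This approach makes it possible to write complex programs and then schedule and optimize them further."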
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {
        "collapsed": false
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "<class 'tvm.ir.module.IRModule'>\n"
          ]
        },
        {
          "data": {
            "text/html": [
              "<div class=\"highlight\" style=\"background: \"><pre style=\"line-height: 125%;\"><span></span><span style=\"color: #007979; font-style: italic\"># from tvm.script import ir as I</span>\n",
              "<span style=\"color: #007979; font-style: italic\"># from tvm.script import tir as T</span>\n",
              "\n",
              "\n",
              "<span style=\"color: #AA22FF\">@I</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>ir_module\n",
              "<span style=\"color: #008000; font-weight: bold\">class</span> <span style=\"color: #0000FF; font-weight: bold\">Module</span>:\n",
              "    <span style=\"color: #AA22FF\">@T</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>prim_func\n",
              "    <span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #0000FF\">main</span>(A: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Buffer((<span style=\"color: #008000\">8</span>,), <span style=\"color: #BA2121\">&quot;float32&quot;</span>), B: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Buffer((<span style=\"color: #008000\">8</span>,), <span style=\"color: #BA2121\">&quot;float32&quot;</span>)):\n",
              "        T<span style=\"color: #AA22FF; font-weight: bold\">.</span>func_attr({<span style=\"color: #BA2121\">&quot;global_symbol&quot;</span>: <span style=\"color: #BA2121\">&quot;main&quot;</span>, <span style=\"color: #BA2121\">&quot;tir.noalias&quot;</span>: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>bool(<span style=\"color: #008000; font-weight: bold\">True</span>)})\n",
              "        <span style=\"color: #007979; font-style: italic\"># with T.block(&quot;root&quot;):</span>\n",
              "        <span style=\"color: #008000; font-weight: bold\">for</span> i <span style=\"color: #008000; font-weight: bold\">in</span> range(<span style=\"color: #008000\">8</span>):\n",
              "            <span style=\"color: #008000; font-weight: bold\">with</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>block(<span style=\"color: #BA2121\">&quot;B&quot;</span>):\n",
              "                vi <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>axis<span style=\"color: #AA22FF; font-weight: bold\">.</span>spatial(<span style=\"color: #008000\">8</span>, i)\n",
              "                T<span style=\"color: #AA22FF; font-weight: bold\">.</span>reads(A[vi])\n",
              "                T<span style=\"color: #AA22FF; font-weight: bold\">.</span>writes(B[vi])\n",
              "                B[vi] <span style=\"color: #AA22FF; font-weight: bold\">=</span> A[vi] <span style=\"color: #AA22FF; font-weight: bold\">+</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>float32(<span style=\"color: #008000\">1</span>)\n",
              "</pre></div>\n"
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        }
      ],
      "source": [
        "@I.ir_module\n",
        "class MyModule:\n",
        "    @T.prim_func\n",
        "    def main(A: T.Buffer(8, \"float32\"), B: T.Buffer(8, \"float32\")):\n",
        "        T.func_attr({\"global_symbol\": \"main\", \"tir.noalias\": True})\n",
        "        for i in range(8):\n",
        "            # A block is an abstraction for computation.\n",
        "            with T.block(\"B\"):\n",
        "                # Define a spatial block iterator and bind it to the value i.\n",
        "                vi = T.axis.spatial(8, i)\n",
        "                B[vi] = A[vi] + 1.0\n",
        "\n",
        "\n",
        "ir_module = MyModule\n",
        "print(type(ir_module))\n",
        "ir_module.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "此外，还可以使用张量表达式 DSL 来编写简单的算子，并将其转换为 IRModule。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {
        "collapsed": false
      },
      "outputs": [
        {
          "data": {
            "text/html": [
              "<div class=\"highlight\" style=\"background: \"><pre style=\"line-height: 125%;\"><span></span><span style=\"color: #007979; font-style: italic\"># from tvm.script import ir as I</span>\n",
              "<span style=\"color: #007979; font-style: italic\"># from tvm.script import tir as T</span>\n",
              "\n",
              "\n",
              "<span style=\"color: #AA22FF\">@I</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>ir_module\n",
              "<span style=\"color: #008000; font-weight: bold\">class</span> <span style=\"color: #0000FF; font-weight: bold\">Module</span>:\n",
              "    <span style=\"color: #AA22FF\">@T</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>prim_func\n",
              "    <span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #0000FF\">main</span>(A: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Buffer((<span style=\"color: #008000\">8</span>,), <span style=\"color: #BA2121\">&quot;float32&quot;</span>), B: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Buffer((<span style=\"color: #008000\">8</span>,), <span style=\"color: #BA2121\">&quot;float32&quot;</span>)):\n",
              "        T<span style=\"color: #AA22FF; font-weight: bold\">.</span>func_attr({<span style=\"color: #BA2121\">&quot;global_symbol&quot;</span>: <span style=\"color: #BA2121\">&quot;main&quot;</span>, <span style=\"color: #BA2121\">&quot;tir.noalias&quot;</span>: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>bool(<span style=\"color: #008000; font-weight: bold\">True</span>)})\n",
              "        <span style=\"color: #007979; font-style: italic\"># with T.block(&quot;root&quot;):</span>\n",
              "        <span style=\"color: #008000; font-weight: bold\">for</span> i0 <span style=\"color: #008000; font-weight: bold\">in</span> range(<span style=\"color: #008000\">8</span>):\n",
              "            <span style=\"color: #008000; font-weight: bold\">with</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>block(<span style=\"color: #BA2121\">&quot;B&quot;</span>):\n",
              "                v_i0 <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>axis<span style=\"color: #AA22FF; font-weight: bold\">.</span>spatial(<span style=\"color: #008000\">8</span>, i0)\n",
              "                T<span style=\"color: #AA22FF; font-weight: bold\">.</span>reads(A[v_i0])\n",
              "                T<span style=\"color: #AA22FF; font-weight: bold\">.</span>writes(B[v_i0])\n",
              "                B[v_i0] <span style=\"color: #AA22FF; font-weight: bold\">=</span> A[v_i0] <span style=\"color: #AA22FF; font-weight: bold\">+</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>float32(<span style=\"color: #008000\">1</span>)\n",
              "</pre></div>\n"
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        }
      ],
      "source": [
        "from tvm import te\n",
        "\n",
        "A = te.placeholder((8,), dtype=\"float32\", name=\"A\")\n",
        "B = te.compute((8,), lambda *i: A(*i) + 1.0, name=\"B\")\n",
        "func = te.create_prim_func([A, B])\n",
        "ir_module_from_te = IRModule({\"main\": func})\n",
        "ir_module_from_te.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## 构建和运行 IRModule\n",
        "\n",
        "我们可以将 IRModule 构建为具有特定目标后端的可运行模块。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {
        "collapsed": false
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "<class 'tvm.driver.build_module.OperatorModule'>\n"
          ]
        }
      ],
      "source": [
        "mod = tvm.build(ir_module, target=\"llvm\")  # The module for CPU backends.\n",
        "print(type(mod))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "准备好输入 array 和输出 array，然后运行该模块。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "metadata": {
        "collapsed": false
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[0. 1. 2. 3. 4. 5. 6. 7.]\n",
            "[1. 2. 3. 4. 5. 6. 7. 8.]\n"
          ]
        }
      ],
      "source": [
        "a = tvm.nd.array(np.arange(8).astype(\"float32\"))\n",
        "b = tvm.nd.array(np.zeros((8,)).astype(\"float32\"))\n",
        "mod(a, b)\n",
        "print(a)\n",
        "print(b)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## 转换 IRModule\n",
        "\n",
        "IRModule 是程序优化的中心数据结构，它可以通过 `Schedule` 进行转换。调度包含多个原语方法，以交互式地转换程序。每个原语都以某些方式改造程序，以带来额外的性能优化。\n",
        "\n",
        "<img src=\"https://daobook.github.io/tvm-web-data/images/design/tvm_tensor_ir_opt_flow.png\" align=\"center\" width=\"100%\">\n",
        "\n",
        "上面的图片是优化张量程序的典型工作流程。首先，需要在由 TVMScript 或 Tensor Expression 创建的初始 IRModule 上创建调度。然后，一连串的调度原语将有助于提高性能。最后，我们可以将其降低并构建为可运行的模块。\n",
        "\n",
        "这里只演示了非常简单的变换。首先，在输入的 `ir_module` 上创建调度。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 6,
      "metadata": {
        "collapsed": false
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "<class 'tvm.tir.schedule.schedule.Schedule'>\n"
          ]
        }
      ],
      "source": [
        "sch = tvm.tir.Schedule(ir_module)\n",
        "print(type(sch))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "将该循环分为 3 个循环，并打印结果。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 7,
      "metadata": {
        "collapsed": false
      },
      "outputs": [
        {
          "data": {
            "text/html": [
              "<div class=\"highlight\" style=\"background: \"><pre style=\"line-height: 125%;\"><span></span><span style=\"color: #007979; font-style: italic\"># from tvm.script import ir as I</span>\n",
              "<span style=\"color: #007979; font-style: italic\"># from tvm.script import tir as T</span>\n",
              "\n",
              "\n",
              "<span style=\"color: #AA22FF\">@I</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>ir_module\n",
              "<span style=\"color: #008000; font-weight: bold\">class</span> <span style=\"color: #0000FF; font-weight: bold\">Module</span>:\n",
              "    <span style=\"color: #AA22FF\">@T</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>prim_func\n",
              "    <span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #0000FF\">main</span>(A: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Buffer((<span style=\"color: #008000\">8</span>,), <span style=\"color: #BA2121\">&quot;float32&quot;</span>), B: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Buffer((<span style=\"color: #008000\">8</span>,), <span style=\"color: #BA2121\">&quot;float32&quot;</span>)):\n",
              "        T<span style=\"color: #AA22FF; font-weight: bold\">.</span>func_attr({<span style=\"color: #BA2121\">&quot;global_symbol&quot;</span>: <span style=\"color: #BA2121\">&quot;main&quot;</span>, <span style=\"color: #BA2121\">&quot;tir.noalias&quot;</span>: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>bool(<span style=\"color: #008000; font-weight: bold\">True</span>)})\n",
              "        <span style=\"color: #007979; font-style: italic\"># with T.block(&quot;root&quot;):</span>\n",
              "        <span style=\"color: #008000; font-weight: bold\">for</span> i_0, i_1, i_2 <span style=\"color: #008000; font-weight: bold\">in</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>grid(<span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">2</span>):\n",
              "            <span style=\"color: #008000; font-weight: bold\">with</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>block(<span style=\"color: #BA2121\">&quot;B&quot;</span>):\n",
              "                vi <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>axis<span style=\"color: #AA22FF; font-weight: bold\">.</span>spatial(<span style=\"color: #008000\">8</span>, i_0 <span style=\"color: #AA22FF; font-weight: bold\">*</span> <span style=\"color: #008000\">4</span> <span style=\"color: #AA22FF; font-weight: bold\">+</span> i_1 <span style=\"color: #AA22FF; font-weight: bold\">*</span> <span style=\"color: #008000\">2</span> <span style=\"color: #AA22FF; font-weight: bold\">+</span> i_2)\n",
              "                T<span style=\"color: #AA22FF; font-weight: bold\">.</span>reads(A[vi])\n",
              "                T<span style=\"color: #AA22FF; font-weight: bold\">.</span>writes(B[vi])\n",
              "                B[vi] <span style=\"color: #AA22FF; font-weight: bold\">=</span> A[vi] <span style=\"color: #AA22FF; font-weight: bold\">+</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>float32(<span style=\"color: #008000\">1</span>)\n",
              "</pre></div>\n"
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        }
      ],
      "source": [
        "# Get the block by its name.\n",
        "block_b = sch.get_block(\"B\")\n",
        "# Get the loops surrounding the block.\n",
        "(i,) = sch.get_loops(block_b)\n",
        "# Tile the loop nest.\n",
        "i_0, i_1, i_2 = sch.split(i, factors=[None, 2, 2])\n",
        "sch.mod.show()"
      ]
    },
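    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The `None` entry in `factors` asks TVM to infer that factor, here 8 / (2 * 2) = 2, which is consistent with the printed `T.grid(2, 2, 2)`. As a rough pure-Python sketch of the index arithmetic only (this is not TVM code, and the helper name `split_index` is hypothetical), the three new loop variables recombine into the original index as `i_0 * 4 + i_1 * 2 + i_2`, and together they visit each of the 8 positions exactly once:\n",
        "\n",
        "```python\n",
        "from itertools import product\n",
        "\n",
        "\n",
        "def split_index(i_0, i_1, i_2, factors=(2, 2, 2)):\n",
        "    # Recombine the split loop variables into the original index.\n",
        "    # With factors (2, 2, 2) this equals i_0 * 4 + i_1 * 2 + i_2.\n",
        "    _, f1, f2 = factors\n",
        "    return (i_0 * f1 + i_1) * f2 + i_2\n",
        "\n",
        "\n",
        "# Iterate in the same order as a (2, 2, 2) loop grid over i_0, i_1, i_2.\n",
        "visited = [split_index(i_0, i_1, i_2) for i_0, i_1, i_2 in product(range(2), repeat=3)]\n",
        "print(visited)\n",
        "```\n",
        "\n",
        "Because the recombined index covers 0 through 7 without gaps or repeats, the split changes only the loop structure, not the computation."
      ]
    },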
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "也可以重新调度循环的顺序。现在将循环 `i_2` 移到 `i_1` 的外面。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 8,
      "metadata": {
        "collapsed": false
      },
      "outputs": [
        {
          "data": {
            "text/html": [
              "<div class=\"highlight\" style=\"background: \"><pre style=\"line-height: 125%;\"><span></span><span style=\"color: #007979; font-style: italic\"># from tvm.script import ir as I</span>\n",
              "<span style=\"color: #007979; font-style: italic\"># from tvm.script import tir as T</span>\n",
              "\n",
              "\n",
              "<span style=\"color: #AA22FF\">@I</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>ir_module\n",
              "<span style=\"color: #008000; font-weight: bold\">class</span> <span style=\"color: #0000FF; font-weight: bold\">Module</span>:\n",
              "    <span style=\"color: #AA22FF\">@T</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>prim_func\n",
              "    <span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #0000FF\">main</span>(A: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Buffer((<span style=\"color: #008000\">8</span>,), <span style=\"color: #BA2121\">&quot;float32&quot;</span>), B: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Buffer((<span style=\"color: #008000\">8</span>,), <span style=\"color: #BA2121\">&quot;float32&quot;</span>)):\n",
              "        T<span style=\"color: #AA22FF; font-weight: bold\">.</span>func_attr({<span style=\"color: #BA2121\">&quot;global_symbol&quot;</span>: <span style=\"color: #BA2121\">&quot;main&quot;</span>, <span style=\"color: #BA2121\">&quot;tir.noalias&quot;</span>: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>bool(<span style=\"color: #008000; font-weight: bold\">True</span>)})\n",
              "        <span style=\"color: #007979; font-style: italic\"># with T.block(&quot;root&quot;):</span>\n",
              "        <span style=\"color: #008000; font-weight: bold\">for</span> i_0, i_2, i_1 <span style=\"color: #008000; font-weight: bold\">in</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>grid(<span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">2</span>, <span style=\"color: #008000\">2</span>):\n",
              "            <span style=\"color: #008000; font-weight: bold\">with</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>block(<span style=\"color: #BA2121\">&quot;B&quot;</span>):\n",
              "                vi <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>axis<span style=\"color: #AA22FF; font-weight: bold\">.</span>spatial(<span style=\"color: #008000\">8</span>, i_0 <span style=\"color: #AA22FF; font-weight: bold\">*</span> <span style=\"color: #008000\">4</span> <span style=\"color: #AA22FF; font-weight: bold\">+</span> i_1 <span style=\"color: #AA22FF; font-weight: bold\">*</span> <span style=\"color: #008000\">2</span> <span style=\"color: #AA22FF; font-weight: bold\">+</span> i_2)\n",
              "                T<span style=\"color: #AA22FF; font-weight: bold\">.</span>reads(A[vi])\n",
              "                T<span style=\"color: #AA22FF; font-weight: bold\">.</span>writes(B[vi])\n",
              "                B[vi] <span style=\"color: #AA22FF; font-weight: bold\">=</span> A[vi] <span style=\"color: #AA22FF; font-weight: bold\">+</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>float32(<span style=\"color: #008000\">1</span>)\n",
              "</pre></div>\n"
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        }
      ],
      "source": [
        "sch.reorder(i_0, i_2, i_1)\n",
        "sch.mod.show()"
      ]
    },
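    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "A pure-Python sketch of the reordered loop nest (an illustration only, not TVM code) shows the effect: swapping the two inner loops changes the order in which indices are visited but not the set of indices, which is why the reorder is safe for this element-wise block.\n",
        "\n",
        "```python\n",
        "order = []\n",
        "for i_0 in range(2):\n",
        "    for i_2 in range(2):  # i_2 is now the middle loop\n",
        "        for i_1 in range(2):  # i_1 is now the innermost loop\n",
        "            order.append(i_0 * 4 + i_1 * 2 + i_2)\n",
        "\n",
        "print(order)\n",
        "```\n",
        "\n",
        "The innermost loop now strides through memory by 2 instead of 1, so on real hardware a reorder like this is a trade-off to evaluate rather than a guaranteed improvement."
      ]
    },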
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### 转化为 GPU 程序\n",
        "\n",
        "如果想在 GPU 上部署模型，线程绑定是必要的。幸运的是，也可以使用原语并做增量变换。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 9,
      "metadata": {
        "collapsed": false
      },
      "outputs": [
        {
          "data": {
            "text/html": [
              "<div class=\"highlight\" style=\"background: \"><pre style=\"line-height: 125%;\"><span></span><span style=\"color: #007979; font-style: italic\"># from tvm.script import ir as I</span>\n",
              "<span style=\"color: #007979; font-style: italic\"># from tvm.script import tir as T</span>\n",
              "\n",
              "\n",
              "<span style=\"color: #AA22FF\">@I</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>ir_module\n",
              "<span style=\"color: #008000; font-weight: bold\">class</span> <span style=\"color: #0000FF; font-weight: bold\">Module</span>:\n",
              "    <span style=\"color: #AA22FF\">@T</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>prim_func\n",
              "    <span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #0000FF\">main</span>(A: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Buffer((<span style=\"color: #008000\">8</span>,), <span style=\"color: #BA2121\">&quot;float32&quot;</span>), B: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>Buffer((<span style=\"color: #008000\">8</span>,), <span style=\"color: #BA2121\">&quot;float32&quot;</span>)):\n",
              "        T<span style=\"color: #AA22FF; font-weight: bold\">.</span>func_attr({<span style=\"color: #BA2121\">&quot;global_symbol&quot;</span>: <span style=\"color: #BA2121\">&quot;main&quot;</span>, <span style=\"color: #BA2121\">&quot;tir.noalias&quot;</span>: T<span style=\"color: #AA22FF; font-weight: bold\">.</span>bool(<span style=\"color: #008000; font-weight: bold\">True</span>)})\n",
              "        <span style=\"color: #007979; font-style: italic\"># with T.block(&quot;root&quot;):</span>\n",
              "        <span style=\"color: #008000; font-weight: bold\">for</span> i_0 <span style=\"color: #008000; font-weight: bold\">in</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>thread_binding(<span style=\"color: #008000\">2</span>, thread<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;blockIdx.x&quot;</span>):\n",
              "            <span style=\"color: #008000; font-weight: bold\">for</span> i_2 <span style=\"color: #008000; font-weight: bold\">in</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>thread_binding(<span style=\"color: #008000\">2</span>, thread<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;threadIdx.x&quot;</span>):\n",
              "                <span style=\"color: #008000; font-weight: bold\">for</span> i_1 <span style=\"color: #008000; font-weight: bold\">in</span> range(<span style=\"color: #008000\">2</span>):\n",
              "                    <span style=\"color: #008000; font-weight: bold\">with</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>block(<span style=\"color: #BA2121\">&quot;B&quot;</span>):\n",
              "                        vi <span style=\"color: #AA22FF; font-weight: bold\">=</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>axis<span style=\"color: #AA22FF; font-weight: bold\">.</span>spatial(<span style=\"color: #008000\">8</span>, i_0 <span style=\"color: #AA22FF; font-weight: bold\">*</span> <span style=\"color: #008000\">4</span> <span style=\"color: #AA22FF; font-weight: bold\">+</span> i_1 <span style=\"color: #AA22FF; font-weight: bold\">*</span> <span style=\"color: #008000\">2</span> <span style=\"color: #AA22FF; font-weight: bold\">+</span> i_2)\n",
              "                        T<span style=\"color: #AA22FF; font-weight: bold\">.</span>reads(A[vi])\n",
              "                        T<span style=\"color: #AA22FF; font-weight: bold\">.</span>writes(B[vi])\n",
              "                        B[vi] <span style=\"color: #AA22FF; font-weight: bold\">=</span> A[vi] <span style=\"color: #AA22FF; font-weight: bold\">+</span> T<span style=\"color: #AA22FF; font-weight: bold\">.</span>float32(<span style=\"color: #008000\">1</span>)\n",
              "</pre></div>\n"
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        }
      ],
      "source": [
        "sch.bind(i_0, \"blockIdx.x\")\n",
        "sch.bind(i_2, \"threadIdx.x\")\n",
        "sch.mod.show()"
      ]
    },
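    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To see which elements each GPU thread handles after binding, here is a minimal pure-Python simulation. It assumes that `blockIdx.x` and `threadIdx.x` each range over 2 values, as in the printed module; no GPU or TVM is involved, and the names `block_idx`, `thread_idx`, and `work` are illustrative. Each (block, thread) pair keeps the serial loop `i_1` and therefore processes two elements:\n",
        "\n",
        "```python\n",
        "work = {}\n",
        "for block_idx in range(2):  # plays the role of blockIdx.x (loop i_0)\n",
        "    for thread_idx in range(2):  # plays the role of threadIdx.x (loop i_2)\n",
        "        # The serial loop i_1 stays inside each thread.\n",
        "        work[(block_idx, thread_idx)] = [\n",
        "            block_idx * 4 + i_1 * 2 + thread_idx for i_1 in range(2)\n",
        "        ]\n",
        "\n",
        "for key, indices in work.items():\n",
        "    print(key, indices)\n",
        "```\n",
        "\n",
        "On a real device the four (block, thread) pairs run in parallel; the sequential loops here only enumerate the mapping."
      ]
    },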
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "绑定线程后，现在用 `cuda` 后端构建 IRModule。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "ctx = tvm.cuda(0)\n",
        "cuda_mod = tvm.build(sch.mod, target=\"cuda\")\n",
        "cuda_a = tvm.nd.array(np.arange(8).astype(\"float32\"), ctx)\n",
        "cuda_b = tvm.nd.array(np.zeros((8,)).astype(\"float32\"), ctx)\n",
        "cuda_mod(cuda_a, cuda_b)\n",
        "print(cuda_a)\n",
        "print(cuda_b)"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3.8.13 ('py38': conda)",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.10.10"
    },
    "vscode": {
      "interpreter": {
        "hash": "28558e8daad512806f5c536a1a04c119185f99f65b79002708a12162d02a79c7"
      }
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
